{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.image import load_img, save_img, img_to_array\n",
    "import numpy as np\n",
    "from scipy.optimize import fmin_l_bfgs_b\n",
    "import time\n",
    "from keras.applications import vgg19\n",
    "from keras.applications.imagenet_utils import preprocess_input\n",
    "from keras import backend as K\n",
    "import tensorflow as tf\n",
    "import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_image_path = 'C:/Users/IS96273/Desktop/content.jpg'\n",
    "style_reference_image_path = 'C:/Users/IS96273/Desktop/style.jpg'\n",
    "\n",
    "iterations = 1\n",
    "\n",
    "img_nrows = 400; img_ncols = 534"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_image(image_path):\n",
    "    img = load_img(image_path, target_size=(img_nrows, img_ncols))\n",
    "    img = img_to_array(img)\n",
    "    img = np.expand_dims(img, axis=0)\n",
    "    img = preprocess_input(img)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def deprocess_image(x):\n",
    "    x = x.reshape((img_nrows, img_ncols, 3))\n",
    "    # Remove zero-center by mean pixel\n",
    "    x[:, :, 0] += 103.939\n",
    "    x[:, :, 1] += 116.779\n",
    "    x[:, :, 2] += 123.68\n",
    "    # 'BGR'->'RGB'\n",
    "    x = x[:, :, ::-1]\n",
    "    x = np.clip(x, 0, 255).astype('uint8')\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_image = K.variable(preprocess_image(base_image_path))\n",
    "x = K.variable(preprocess_image(base_image_path))\n",
    "style_reference_image = K.variable(preprocess_image(style_reference_image_path))\n",
    "\n",
    "random_pixels = np.random.randint(256, size=(img_nrows, img_ncols, 3))\n",
    "combination_image = preprocess_input(np.expand_dims(random_pixels, axis=0))\n",
    "combination_image = K.variable(combination_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "content_model = vgg19.VGG19(input_tensor=base_image, weights='imagenet', include_top=False)\n",
    "style_model = vgg19.VGG19(input_tensor=style_reference_image, weights='imagenet', include_top=False)\n",
    "\n",
    "content_outputs = dict([(layer.name, layer.output) for layer in content_model.layers])\n",
    "style_outputs = dict([(layer.name, layer.output) for layer in style_model.layers])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_model = vgg19.VGG19(input_tensor=combination_image, weights='imagenet', include_top=False)\n",
    "generated_outputs = dict([(layer.name, layer.output) for layer in generated_model.layers])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gram_matrix(x):\n",
    "    features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))\n",
    "    gram = K.dot(features, K.transpose(features))\n",
    "    return gram\n",
    "\n",
    "def style_loss(style, combination):\n",
    "    S = gram_matrix(style)\n",
    "    C = gram_matrix(combination)\n",
    "    channels = 3\n",
    "    size = img_nrows * img_ncols\n",
    "    return K.sum(K.square(S - C)) / (4. * (pow(channels,2)) * (pow(size,2)))\n",
    "\n",
    "def content_loss(base, combination):\n",
    "    return K.sum(K.square(combination - base))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = K.variable(0)\n",
    "\n",
    "base_image_features = content_outputs['block5_conv2'][0]\n",
    "combination_features = generated_outputs['block5_conv2'][0]\n",
    "contentloss = content_loss(base_image_features, combination_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']\n",
    "\n",
    "styleloss = K.variable(0)\n",
    "\n",
    "for layer_name in feature_layers:\n",
    "    style_reference_features = style_outputs[layer_name][0]\n",
    "    combination_features = generated_outputs[layer_name][0]\n",
    "    styleloss = styleloss + style_loss(style_reference_features, combination_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.025; beta = 0.2\n",
    "loss = alpha * contentloss + beta * styleloss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "grads = K.gradients(loss, combination_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nif isinstance(grads, (list, tuple)):\\n    outputs += grads\\nelse:\\n    outputs.append(grads)\\n'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs = [loss]\n",
    "outputs += grads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "f_outputs = K.function([combination_image], outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_loss_and_grads(x):\n",
    "    x = x.reshape((1, img_nrows, img_ncols, 3))\n",
    "    outs = f_outputs([x])\n",
    "    loss_value = outs[0]\n",
    "    if len(outs[1:]) == 1:\n",
    "        grad_values = outs[1].flatten().astype('float64')\n",
    "    else:\n",
    "        grad_values = np.array(outs[1:]).flatten().astype('float64')\n",
    "    return loss_value, grad_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Evaluator(object):\n",
    "\n",
    "    def __init__(self):\n",
    "        self.loss_value = None\n",
    "        self.grads_values = None\n",
    "\n",
    "    def loss(self, x):\n",
    "        assert self.loss_value is None\n",
    "        loss_value, grad_values = eval_loss_and_grads(x)\n",
    "        self.loss_value = loss_value\n",
    "        self.grad_values = grad_values\n",
    "        return self.loss_value\n",
    "\n",
    "    def grads(self, x):\n",
    "        assert self.loss_value is not None\n",
    "        grad_values = np.copy(self.grad_values)\n",
    "        self.loss_value = None\n",
    "        self.grad_values = None\n",
    "        return grad_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator = Evaluator()\n",
    "\n",
    "x = preprocess_image(base_image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch  0\n"
     ]
    }
   ],
   "source": [
    "for i in range(0,iterations):\n",
    "    print(\"epoch \",i)\n",
    "    x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(), fprime=evaluator.grads, maxfun=20)\n",
    "    img = deprocess_image(x.copy())\n",
    "    fname = 'C:/Users/IS96273/Desktop/generated_%d.png' % i\n",
    "    save_img(fname, img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
